import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
import collections
from tqdm import tqdm


# Seed NumPy's global RNG once, at import time, for reproducibility of the
# np.random-based sampling done in this module.
_GLOBAL_SEED = 2191
np.random.seed(_GLOBAL_SEED)


def filenameToPILImage(x):
    """Open the image file at path ``x`` with PIL and return the image."""
    return Image.open(x)

class TieredImagenetOneShotDataset(data.Dataset):
    """Episodic few-shot dataset over a tieredImageNet-style folder tree.

    Expects ``<dataroot>/<type>/<class_name>/<image file>``. At construction
    time ``nEpisodes`` episodes are pre-sampled with NumPy's global RNG. Each
    episode draws ``classes_per_set`` distinct classes. For each class it
    draws ``samples_per_class`` support images and up to
    ``samples_per_query`` query images that are disjoint from the support
    images.

    ``__getitem__`` returns ``(support_set_x, support_set_y, target_x,
    target_y)``:

    - Image tensors have shape ``(N, 3, ImageSize, ImageSize)``.
    - Label tensors are int64, remapped to ``0..k-1`` within the episode.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, dataroot='/path/to/tiered_imagenet', type='train',
                 nEpisodes=1000, classes_per_set=10, samples_per_class=1,
                 samples_per_query=5, ImageSize=224):
        """
        Args:
            dataroot: Dataset root directory.
            type: Split sub-directory name (e.g. 'train'). This shadows the
                ``type`` builtin inside this method; the name is kept for
                backward compatibility with keyword callers.
            nEpisodes: Number of pre-sampled episodes (also ``len(self)``).
            classes_per_set: Number of classes per episode (N-way).
            samples_per_class: Support images per class (K-shot).
            samples_per_query: Maximum query images per class.
            ImageSize: Images are resized to ``ImageSize x ImageSize``.
        """
        self.PiLImageResize = lambda x: x.resize((ImageSize, ImageSize))

        self.ImageSize = ImageSize
        self.nEpisodes = nEpisodes
        self.classes_per_set = classes_per_set
        self.samples_per_class = samples_per_class
        self.samples_per_query = samples_per_query
        # Nominal support / query counts per episode. The actual query count
        # can be smaller when a class does not have enough spare images.
        self.n_samples = self.samples_per_class * self.classes_per_set
        self.n_samplesNShot = self.samples_per_query * self.classes_per_set

        # Path -> PIL image -> RGB -> square resize -> float tensor.
        # The RGB conversion guarantees 3 channels: grayscale, RGBA or CMYK
        # files would otherwise not fit the (3, H, W) slots in _load_images.
        self.transform = transforms.Compose([
            filenameToPILImage,
            lambda img: img.convert('RGB'),
            self.PiLImageResize,
            transforms.ToTensor(),
        ])

        # Load data from folder structure.
        self.data_dir = os.path.join(dataroot, type)
        self.data = self.load_from_folders(self.data_dir)
        self.data = collections.OrderedDict(sorted(self.data.items()))
        # Global class index. It is only used to group and remap labels
        # inside each episode.
        self.classes_dict = {class_name: i for i, class_name in enumerate(self.data.keys())}

        self.create_episodes(self.nEpisodes)

    def load_from_folders(self, data_dir):
        """Map each class sub-folder of ``data_dir`` to its image files.

        Paths are relative to ``data_dir`` (``<class_name>/<file>``), so that
        ``os.path.dirname`` recovers the class name in ``__getitem__``.
        Listings are sorted so that, with a fixed NumPy seed, the sampled
        episodes do not depend on the filesystem's directory-listing order.
        Non-directory entries of ``data_dir`` are ignored.
        """
        data_dict = {}
        for class_name in sorted(os.listdir(data_dir)):
            class_dir = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            data_dict[class_name] = sorted(
                os.path.join(class_name, f)
                for f in os.listdir(class_dir)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
        return data_dict

    def create_episodes(self, episodes):
        """Pre-sample ``episodes`` episodes using NumPy's global RNG.

        Fills ``self.support_set_x_batch`` and ``self.target_x_batch``. Each
        entry is a list with one item per selected class, and each item is a
        list of ``data_dir``-relative image paths.

        Query images are drawn from the class images not used as support. If
        fewer than ``samples_per_query`` remain, all remaining images are
        used.

        Raises:
            ValueError: from ``np.random.choice`` when there are fewer classes
                than ``classes_per_set``, or a class has fewer images than
                ``samples_per_class``.
        """
        self.support_set_x_batch = []
        self.target_x_batch = []
        # Hoisted: this was rebuilt for every class of every episode.
        class_names = list(self.data.keys())

        for _ in range(episodes):
            # Select classes_per_set distinct classes.
            selected_classes = np.random.choice(len(class_names),
                                                self.classes_per_set,
                                                False)
            support_set_x = []
            target_x = []

            for c in selected_classes:
                class_samples = self.data[class_names[c]]

                support_idx = np.random.choice(len(class_samples),
                                               self.samples_per_class,
                                               False)
                support_set_x.append([class_samples[i] for i in support_idx])

                # Query candidates must not reuse support images.
                used = set(support_idx.tolist())
                remaining = [i for i in range(len(class_samples)) if i not in used]
                if len(remaining) >= self.samples_per_query:
                    query_idx = np.random.choice(remaining,
                                                 self.samples_per_query,
                                                 False)
                else:
                    # Not enough images left: take everything available.
                    query_idx = remaining
                target_x.append([class_samples[i] for i in query_idx])

            self.support_set_x_batch.append(support_set_x)
            self.target_x_batch.append(target_x)

    def __getitem__(self, index):
        """Load episode ``index``.

        Returns:
            ``(support_set_x, support_set_y, target_x, target_y)``.

            - ``support_set_x`` and ``target_x`` are float tensors of shape
              ``(num_paths, 3, ImageSize, ImageSize)``.
            - ``target_x`` has one row per query image actually sampled, so
              it may have fewer than ``n_samplesNShot`` rows.
            - Labels are int64 tensors aligned with the image rows. They are
              remapped to ``0..k-1`` in ascending order of the episode's
              global class indices.
        """
        support_paths = [p for per_class in self.support_set_x_batch[index] for p in per_class]
        target_paths = [p for per_class in self.target_x_batch[index] for p in per_class]

        # Relative paths are "<class_name>/<file>", so dirname is the class.
        support_global = [self.classes_dict[os.path.dirname(p)] for p in support_paths]
        target_global = [self.classes_dict[os.path.dirname(p)] for p in target_paths]

        # Sized from the real path counts, so no uninitialized rows are
        # returned when a class could not supply a full query set.
        support_set_x = self._load_images(support_paths)
        target_x = self._load_images(target_paths)

        # Convert the targets to [0, number of classes in this episode).
        remap = {g: i for i, g in enumerate(sorted(set(support_global)))}
        support_set_y = torch.tensor([remap[g] for g in support_global], dtype=torch.long)
        target_y = torch.tensor([remap[g] for g in target_global], dtype=torch.long)

        return support_set_x, support_set_y, target_x, target_y

    def _load_images(self, relative_paths):
        """Apply ``self.transform`` to each ``data_dir``-relative path.

        Returns a ``(len(relative_paths), 3, ImageSize, ImageSize)`` tensor.
        Every row is written, so the uninitialized allocation is safe.
        """
        images = torch.empty(len(relative_paths), 3, self.ImageSize, self.ImageSize)
        for i, rel_path in enumerate(relative_paths):
            images[i] = self.transform(os.path.join(self.data_dir, rel_path))
        return images

    def __len__(self):
        """Number of pre-sampled episodes."""
        return self.nEpisodes




def get_tiered_base_class_images(root_dir, n_base, experts=1, image_size=224, overlap=0, num_per_base=100):
    """
    Load a fixed number of images per randomly chosen base class from the
    tieredImageNet training split (``<root_dir>/train/<class_name>/<image>``).

    Args:
        root_dir (str): Path to the dataset root directory.
        n_base (int): Number of base classes per expert.
        experts (int): Number of experts. ``n_base * experts`` distinct
            classes are drawn with NumPy's global RNG.
        image_size (int): Images are resized (shorter side) and then
            center-cropped to ``image_size x image_size`` (default: 224).
        overlap (int): Intended number of overlapping classes between
            experts. NOTE: currently not used by this function.
        num_per_base (int): Images per class (default: 100). Classes with
            fewer images are padded by randomly duplicating their images, as
            are classes where some images fail to load.

    Returns:
        tuple: (class_images, class_labels) where:
            class_images: torch.Tensor of shape
                [n_classes, num_per_base, 3, image_size, image_size],
                normalized with mean=(0.485, 0.456, 0.406) and
                std=(0.229, 0.224, 0.225).
            class_labels: list of class names aligned with dim 0.
        Classes without any loadable image are skipped, so ``n_classes`` can
        be smaller than ``n_base * experts``.

    Raises:
        ValueError: if fewer than ``n_base * experts`` class folders exist
            (from ``np.random.choice``), or if no selected class yields a
            loadable image.
    """
    train_dir = os.path.join(root_dir, 'train')
    class_dict = _index_tiered_class_images(train_dir)

    # Randomly select n_base * experts distinct classes.
    selected_classes = np.random.choice(list(class_dict.keys()), n_base * experts, replace=False)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    class_images = []
    class_labels = []
    for class_name in selected_classes:
        stacked = _load_tiered_class_images(class_name, class_dict[class_name],
                                            transform, num_per_base)
        if stacked is None:
            # Skip classes with no valid images.
            continue
        class_images.append(stacked)
        class_labels.append(class_name)

    if not class_images:
        raise ValueError("No valid images found in the selected classes")

    # Every entry has exactly num_per_base rows, so stacking cannot fail on
    # mismatched shapes.
    return torch.stack(class_images), class_labels


def _index_tiered_class_images(train_dir):
    """Map every class sub-folder of ``train_dir`` to its sorted image paths.

    Sorting keeps the seeded class and image sampling independent of the
    filesystem's directory-listing order. Suffixes are matched
    case-insensitively.
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    class_dict = {}
    for class_name in sorted(os.listdir(train_dir)):
        class_dir = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        class_dict[class_name] = sorted(
            os.path.join(class_dir, f)
            for f in os.listdir(class_dir)
            if f.lower().endswith(image_exts)
        )
    return class_dict


def _load_tiered_class_images(class_name, paths, transform, num_per_base):
    """Load exactly ``num_per_base`` transformed images for one class.

    Returns a tensor stacked along a new first dimension of size
    ``num_per_base``, or ``None`` when the class has no loadable image.
    """
    if not paths:
        # np.random.choice cannot pad from an empty list; skip the class.
        print(f"Class {class_name} has no images, skipping")
        return None

    # Pad the image list with duplicates if there are fewer than num_per_base.
    if len(paths) < num_per_base:
        original_length = len(paths)
        needed = num_per_base - original_length
        duplicates = np.random.choice(paths, needed, replace=True)
        paths = paths + duplicates.tolist()
        print(f"Class {class_name} padded from {original_length} to {num_per_base} images")

    loaded = []
    for img_path in paths[:num_per_base]:
        try:
            with Image.open(img_path) as img:
                tensor = transform(img.convert('RGB'))
        except Exception:
            # Best effort: report and skip unreadable or corrupt files.
            print(f"Error loading image: {img_path}")
            continue
        loaded.append(tensor)

    if not loaded:
        return None

    # Failed loads would leave this class short and break the final
    # torch.stack across classes. Refill by duplicating images that did load.
    shortfall = num_per_base - len(loaded)
    if shortfall > 0:
        refill = np.random.choice(len(loaded), shortfall, replace=True)
        loaded += [loaded[i] for i in refill]

    return torch.stack(loaded)

if __name__ == '__main__':
    import argparse

    # Smoke test: load base-class images and print a short summary. The
    # defaults reproduce the previously hard-coded settings.
    parser = argparse.ArgumentParser(
        description="Load tieredImageNet base-class images and print a summary.")
    parser.add_argument('--root-dir', default='/home/lyc/lyc_vdb/datasets/tiered_imagenet',
                        help="dataset root directory containing the 'train' split")
    parser.add_argument('--n-base', type=int, default=100,
                        help="number of base classes per expert")
    parser.add_argument('--experts', type=int, default=1,
                        help="number of experts")
    parser.add_argument('--image-size', type=int, default=224,
                        help="side length of the resized/cropped images")
    parser.add_argument('--num-per-base', type=int, default=100,
                        help="images loaded per class")
    args = parser.parse_args()

    base_images, base_labels = get_tiered_base_class_images(
        root_dir=args.root_dir,
        n_base=args.n_base,
        experts=args.experts,
        image_size=args.image_size,
        num_per_base=args.num_per_base,
    )

    print(f"Generated base class images tensor with shape: {base_images.shape}")
    print(f"Number of classes: {len(base_labels)}")
    print(f"Example class names: {base_labels[:5]}")